[TorchToLinalg] Direct lowering from Torch to Linalg for torch.aten.convolution_backward #4384
base: main
Conversation
Force-pushed from 4f1cb20 to 8e2b616
@zjgarvey hey! May I ask you to take a look when you're available? Thank you in advance for the review.
zjgarvey left a comment
Nice! This is an excellent start.
We need to keep the existing decomposition for other backends. I have a few other comments for you to look at, but that's the biggest blocker right now.
    rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
    cstFalse, cstNone);
}
We should keep the decomposition. E.g., TOSA and StableHLO still rely on this pattern. The purpose of the backend_legal_ops option in torch-decompose-complex-ops is specifically to prevent selected decomposition patterns from being applied.
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
  weightFlipDims.push_back(spatialStartDimIdx + i);
If the weight shape is static at index i and the dim size there is 1, don't add that dim to the flip. We definitely see a lot of 1x1 filter convs, and the no-op flip doesn't get folded easily IIRC.
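For intuition only (plain Python, not torch-mlir code), a toy sketch of why the flip can be skipped for unit spatial dims: reversing a size-1 axis is the identity, so omitting the flip for 1x1 filters changes nothing. The helper name is illustrative.

```python
# Toy illustration: flipping (reversing) an axis of size 1 is a no-op,
# so a flip for the unit spatial dims of a 1x1 filter can be skipped.
def flip_axis(t, axis):
    """Reverse a nested-list 'tensor' along one axis."""
    if axis == 0:
        return t[::-1]
    return [flip_axis(sub, axis - 1) for sub in t]

w = [[[1.0]], [[2.0]]]  # shape (2, 1, 1): two 1x1 filters
# Flipping either unit axis changes nothing:
assert flip_axis(w, 1) == w and flip_axis(w, 2) == w
# A size-2 axis does change, so only unit dims give a guaranteed no-op:
assert flip_axis(w, 0) == [[[2.0]], [[1.0]]]
```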
    createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy);
gradOutputSliced = tensor::InsertSliceOp::create(
    rewriter, loc,
    torch_to_linalg::removeSizeInformation(rewriter, loc,
Why remove the size info?
Also maybe "sliced" is a misleading name. Scattered? Or something generic like "Modified" since you are just padding when stride == 1.
    createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
if (isGroupedConvBwd) {
  auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
Should the init just be made on the expanded shape here (instead of expanding the init)? This probably gets folded, but I think it would be better to generate simpler IR when possible.
// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the input spatial dimension, `k` is the kernel
// dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
// gradient of the output tensor. `dLdx` is the data-gradient tensor.
Might be good to mention that dLdy is the stride/padding modified grad output tensor here.
And that w is flipped along spatial dims.
}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
There might be a util for this already, like "createReductionGeneric" or something. In any case, might be good to call this something a little more specific (pun intended).
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr n, c, o, f, k;
      bindDims(context, n, c, o, f, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {f, c, k};
      SmallVector<AffineExpr> outExprs = {n, c, o};
      indexingMaps = {AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, weiExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, c, oh, ow, f, kh, kw;
      bindDims(context, n, c, oh, ow, f, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
      indexingMaps = {AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, weiExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
      bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
                                         d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
      indexingMaps = {AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, weiExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr n, g, cg, o, fg, k;
      bindDims(context, n, g, cg, o, fg, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
      SmallVector<AffineExpr> outExprs = {n, g, cg, o};
      indexingMaps = {AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, weiExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, g, cg, oh, ow, fg, kh, kw;
      bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
                                         d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
      indexingMaps = {AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, weiExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
      bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {
          n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
      indexingMaps = {AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, weiExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}
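As a sanity model of the backward-data computation above, here is a hedged pure-Python sketch (not the actual lowering; 1-D, ungrouped, stride 1, padding 0). It mirrors the generic-op formula: zero-pad dLdy by dilation * (K - 1) on each side, flip the weight along k, then contract. The helper name and shapes are illustrative.

```python
def conv1d_backward_data(grad_out, weight, in_len, dilation=1):
    """dLdx for a 1-D, stride-1, unpadded, ungrouped convolution.

    Mirrors the generic op: pad dLdy by dilation*(K-1) per side, flip the
    weight along k, then
    dLdx[n, c, o] = sum_{f, k} go_pad[n, f, d0*k + o] * w_flip[f, c, k].
    """
    F, C, K = len(weight), len(weight[0]), len(weight[0][0])
    pad = dilation * (K - 1)
    go_pad = [[[0.0] * pad + fo + [0.0] * pad for fo in n_rows]
              for n_rows in grad_out]
    w_flip = [[list(reversed(ck)) for ck in f_rows] for f_rows in weight]
    return [[[sum(go_pad[n][f][dilation * k + o] * w_flip[f][c][k]
                  for f in range(F) for k in range(K))
              for o in range(in_len)]
             for c in range(C)]
            for n in range(len(grad_out))]

# Check against the hand-derived gradient of y[o] = 2*x[o] + 5*x[o+1]:
grad_out = [[[1.0, 2.0, 3.0]]]  # dL/dy, shape (N=1, F=1, O=3)
weight = [[[2.0, 5.0]]]         # shape (F=1, C=1, K=2)
assert conv1d_backward_data(grad_out, weight, in_len=4) == \
    [[[2.0, 9.0, 16.0, 15.0]]]
```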

static void initIndexingMapsAndIteratorTypesForWeightBwd(
    OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
    int numSpatialDims, const SmallVector<int64_t> &strideInts,
    const SmallVector<int64_t> &dilationInts,
    SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
  // To calculate the convolution backward-weight, we use a generic
  // operation, a generalization of the convolution operation that can
  // handle any number of spatial dimensions. It is defined as follows:
  // ```
  // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
  //                        for n in range(batch_size)
  //                        for o in range(output_spatial_dims))
  // ```
  // where `n` is the batch dimension, `g` is the group dimension,
  // `c` is the input channel dimension, `f` is the output channel
  // dimension, `o` is the output spatial dimension, `k` is the kernel
  // dimension, `d0` is dilation, and `s0` is stride. `x` is the input
  // tensor, `dLdy` is the gradient of the output tensor, and `dLdw` is
  // the weight-gradient tensor.
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr f, c, k, n, o;
      bindDims(context, f, c, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, f, o};
      SmallVector<AffineExpr> outExprs = {f, c, k};
      indexingMaps = {AffineMap::get(5, 0, inExprs, context),
                      AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr f, c, kh, kw, n, oh, ow;
      bindDims(context, f, c, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
      indexingMaps = {AffineMap::get(7, 0, inExprs, context),
                      AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
      bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
      indexingMaps = {AffineMap::get(9, 0, inExprs, context),
                      AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr g, fg, cg, k, n, o;
      bindDims(context, g, fg, cg, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, g, fg, o};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
      indexingMaps = {AffineMap::get(6, 0, inExprs, context),
                      AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr g, fg, cg, kh, kw, n, oh, ow;
      bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
      indexingMaps = {AffineMap::get(8, 0, inExprs, context),
                      AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
      bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
      indexingMaps = {AffineMap::get(10, 0, inExprs, context),
                      AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}
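Likewise, a hedged pure-Python model of the backward-weight formula in the comment above (not the actual lowering; 1-D, ungrouped, padding 0; the helper name is illustrative):

```python
def conv1d_backward_weight(inp, grad_out, K, stride=1, dilation=1):
    """dLdw for a 1-D, unpadded, ungrouped convolution.

    Mirrors the weight-bwd generic op:
    dLdw[f, c, k] = sum_{n, o} x[n, c, d0*k + s0*o] * dLdy[n, f, o].
    """
    N, C = len(inp), len(inp[0])
    F, O = len(grad_out[0]), len(grad_out[0][0])
    return [[[sum(inp[n][c][dilation * k + stride * o] * grad_out[n][f][o]
                  for n in range(N) for o in range(O))
              for k in range(K)]
             for c in range(C)]
            for f in range(F)]

# With stride 2 and K = 2, y[o] = w0*x[2o] + w1*x[2o+1], so
# dL/dw = [x[0]*g0 + x[2]*g1, x[1]*g0 + x[3]*g1]:
x = [[[1.0, 2.0, 3.0, 4.0, 5.0]]]  # (N=1, C=1, L=5)
g = [[[1.0, 10.0]]]                # dL/dy, (N=1, F=1, O=2)
assert conv1d_backward_weight(x, g, K=2, stride=2) == [[[31.0, 42.0]]]
```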
There must be a better way.
E.g., you could make the AffineExprs for stride, dilation, spatial dims, etc. into SmallVector<AffineExpr>s. I don't even think there need to be conditionals on anything other than:

SmallVector<AffineExpr> lhsExprs = isGrouped ? {n, g, c} : {n, c};
// loop over spatial dims and add expressions...

Everything else can be like:

int64_t numIterators = 3; // batch, parallel channel, reduction channel
numIterators += static_cast<int64_t>(isGrouped);
numIterators += numSpatialDims * 2; // parallel + reduction spatial dims
indexingMaps = {AffineMap::get(numIterators, 0, lhsExprs, context),
                AffineMap::get(numIterators, 0, rhsExprs, context),
                AffineMap::get(numIterators, 0, outExprs, context)};
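To make the suggestion concrete, here is a toy Python model (strings stand in for AffineExprs; all names are illustrative): one loop over spatial dims builds all three weight-bwd maps, with no per-rank or per-grouping branches beyond prepending the group dim.

```python
# Toy model of building the weight-bwd indexing maps in one loop over
# spatial dims, instead of separate branches per rank and grouping.
def weight_bwd_maps(num_spatial_dims, strides, dilations, grouped):
    in_exprs = ["n"] + (["g"] if grouped else []) + ["c"]
    go_exprs = ["n"] + (["g"] if grouped else []) + ["f"]
    out_exprs = (["g"] if grouped else []) + ["f", "c"]
    for i in range(num_spatial_dims):
        in_exprs.append(f"{dilations[i]}*k{i} + {strides[i]}*o{i}")
        go_exprs.append(f"o{i}")
        out_exprs.append(f"k{i}")
    # batch + parallel channel + reduction channel, plus group dim,
    # plus one parallel and one reduction dim per spatial axis:
    num_iters = 3 + int(grouped) + 2 * num_spatial_dims
    return num_iters, in_exprs, go_exprs, out_exprs

# Grouped 2-D case matches the explicit branch: 8 iterators and the
# same in/go/out expression lists.
n_it, ie, ge, oe = weight_bwd_maps(2, [2, 2], [1, 1], grouped=True)
assert n_it == 8
assert ie == ["n", "g", "c", "1*k0 + 2*o0", "1*k1 + 2*o1"]
assert ge == ["n", "g", "f", "o0", "o1"]
assert oe == ["g", "f", "c", "k0", "k1"]
```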
Description:
- Added a direct lowering of torch.aten.convolution_backward from Torch to Linalg, enabled by default. The pass generates linalg.generic ops instead of linalg.conv_<> ops for better lowering.
- Removed DecomposeAtenConvolutionBackwardOp from Torch/Transforms/DecomposeComplexOps.cpp.
- Added tests in convolution_backward.mlir. Also added more test cases for better test coverage.
Issue:
torch.aten.convolution_backwardfrom Torch to Linalg. Enabled this pass by default. The pass generateslinalg.genericops instead oflinalg.conv_<>for better lowering.DecomposeAtenConvolutionBackwardOpfromTorch/Transforms/DecomposeComplexOps.cpp.convolution_backward.mlir. Also added more test cases for better test coverage.Issue: